import torch



def get_device(target_gpu = 1):
    target_gpu = 1 if target_gpu is None else target_gpu

    # 检查 GPU 数量
    if torch.cuda.device_count() >= (target_gpu+1):
        # 设置默认使用第 1 个 GPU（索引为 0）
        torch.cuda.set_device(target_gpu)
        print(f"已设置默认 GPU 索引为 {target_gpu} 块")
    else:
        print(f"GPU 数量不足 {(target_gpu+1)} 块，无法使用索引为 {target_gpu} 的 GPU")


    device = torch.device(f"cuda:{target_gpu}" if torch.cuda.device_count() > target_gpu else "cpu")
    return device